<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Adding BetterTransformer support for new architectures

Want to add support for a new model in `BetterTransformer`, the fast path of the PyTorch Transformer API? Follow this guide!

## Models that should be supported

In theory, any model with a transformer encoder layer similar to the classic encoder described in the ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) paper should be supported.
More specifically, a model that has an encoder block with a multi-head attention module (with pre- or post-attention layer norm) should be convertible to its `BetterTransformer` equivalent. The conditions can be summarized as follows:

- Uses the classic multi-head attention module (for example, [DeBERTa](https://arxiv.org/abs/2006.03654) cannot be supported)
- Uses either the `gelu` or `relu` activation function
- Has an even number of attention heads
- Does not use any attention bias (for example, `T5` uses attention bias and therefore cannot be supported)
- Uses the same `eps` value for the first and second layer norms of each layer
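
The checklist above can be sketched as a small pre-flight check. This is a hypothetical helper, not part of `optimum`: it assumes a config object exposing `hidden_act` and `num_attention_heads` (the attribute names used by `BertConfig`), and the two layer norm `eps` values are passed explicitly because their location depends on the architecture. It cannot detect a non-standard attention module or an attention bias; those require reading the modeling code.

```python
from types import SimpleNamespace


def find_blocking_issues(config, norm1_eps, norm2_eps):
    """Return the reasons (possibly none) a config fails the checklist above.

    Hypothetical helper for illustration only; `optimum` runs its own checks.
    """
    issues = []
    if config.hidden_act not in ("gelu", "relu"):
        issues.append(f"unsupported activation: {config.hidden_act}")
    if config.num_attention_heads % 2 != 0:
        issues.append(f"odd number of attention heads: {config.num_attention_heads}")
    if norm1_eps != norm2_eps:
        issues.append(f"layer norm eps mismatch: {norm1_eps} != {norm2_eps}")
    return issues


# Stand-in for a model config, using values from `bert-base-uncased`.
bert_like = SimpleNamespace(hidden_act="gelu", num_attention_heads=12)
print(find_blocking_issues(bert_like, norm1_eps=1e-12, norm2_eps=1e-12))
```

An empty list means that none of the three automatically checkable conditions is violated.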

## How to convert a model into its `BetterTransformer` format?

### Step 1: Identifying the source layer to change

First, go to `optimum/bettertransformer/__init__.py` and look for the dictionary `BetterTransformerManager.MODEL_MAPPING`. It maps each model type to a `Tuple[str, BetterTransformerBaseLayer]` made of the name of the `nn.Module` that can be converted to its `BetterTransformer` equivalent and the corresponding `BetterTransformer` layer class.

Let us go through it step by step for `Bert`. First, we need to identify the layers that need to be replaced:
```python
>>> from transformers import AutoModel

>>> model = AutoModel.from_pretrained("bert-base-uncased")
>>> print(model)  # doctest: +IGNORE_RESULT
...
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
```

You can clearly see that the layers that need to be replaced are the `BertLayer` modules since they contain the whole encoder layer module.
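
When the printed module tree is long, counting module class names can make the repeated block easier to spot. The sketch below uses a toy stand-in model (the `ToyLayer` and `ToyEncoder` classes are made up for illustration) so that it runs without downloading any weights; the same `named_modules()` loop works on any `nn.Module`.

```python
from collections import Counter

import torch.nn as nn


class ToyLayer(nn.Module):
    """Made-up stand-in for an encoder block such as `BertLayer`."""

    def __init__(self, dim=8):
        super().__init__()
        self.attention = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, dim)


class ToyEncoder(nn.Module):
    """Made-up stand-in for a full encoder: a stack of identical blocks."""

    def __init__(self, num_layers=3):
        super().__init__()
        self.layer = nn.ModuleList([ToyLayer() for _ in range(num_layers)])


model = ToyEncoder()
# Count how often each module class appears anywhere in the model tree.
class_counts = Counter(type(module).__name__ for _, module in model.named_modules())
print(class_counts["ToyLayer"])  # one entry per encoder block
```

The class that appears once per encoder block and wraps both the attention and the feed-forward parts is the candidate for replacement.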

### Step 2: Building the `xxxLayerBetterTransformer` module

Check that the identified module is not already copied from another module (by inspecting the source code in [`transformers`](https://github.com/huggingface/transformers) and checking that the class definition does not start with `# Copied from ...`). If it is not, create a class in `optimum/bettertransformer/models/encoder_models.py`.
Start with these lines:
```python
import torch
import torch.nn as nn

from ..base import BetterTransformerBaseLayer


class BertLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, bert_layer, config):
...
```

Now, make sure to fill in all the necessary attributes. The list of attributes is:

- `in_proj_weight`
- `in_proj_bias`
- `out_proj_weight`
- `out_proj_bias`
- `linear1_weight`
- `linear1_bias`
- `linear2_weight`
- `linear2_bias`
- `norm1_eps`
- `norm1_weight`
- `norm1_bias`
- `norm2_weight`
- `norm2_bias` 
- `num_heads`
- `embed_dim`

Note that these attributes correspond to all the components needed to run a Transformer encoder module; see Figure 1 of the ["Attention Is All You Need"](https://arxiv.org/pdf/1706.03762.pdf) paper.

Sometimes the `query`, `key` and `value` layers need to be made contiguous; check the [`encoder_models.py`](https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/encoder_models.py) file to understand more.

Once you have filled in all these attributes, also make sure to add the lines:
```python
self.is_last_layer = False
self.validate_bettertransformer()
```
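
As a sketch of how the fused attention attributes can be filled, the separate `query`, `key` and `value` projections can be concatenated along the output dimension into a single `in_proj_weight` and `in_proj_bias`. The toy `nn.Linear` layers below are stand-ins for the attention projections of a real layer and the sizes are arbitrary; check the existing classes in `optimum` for the exact attribute paths of each architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim = 8
# Stand-ins for the `query`, `key` and `value` projections of an attention block.
query, key, value = (nn.Linear(embed_dim, embed_dim) for _ in range(3))

# Stack the three projections so that one matmul produces q, k and v at once.
in_proj_weight = nn.Parameter(torch.cat([query.weight, key.weight, value.weight]))
in_proj_bias = nn.Parameter(torch.cat([query.bias, key.bias, value.bias]))

x = torch.randn(2, 5, embed_dim)
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

# Compare the fused projection with the separate query projection.
print(torch.allclose(q, query(x), atol=1e-5))
```

The concatenation order matters: the fused kernel expects the query rows first, then the key rows, then the value rows.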


### Step 3: Building the forward pass

First of all, start with the line `super().forward_checker()`. This is needed so that the parent class can run all its safety checks beforehand.

After the first forward pass, the hidden states need to be *nested* using the attention mask. Once they are nested, the attention mask is no longer needed and can be set to `None`. This is how the forward pass is built for `Bert`. These lines should remain largely the same across models, although the shape of the attention mask sometimes differs between models.
```python
super().forward_checker()

if hidden_states.is_nested:
    attention_mask = None

if attention_mask is not None:
    # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
    # 0->false->keep this token -inf->true->mask this token
    attention_mask = attention_mask.bool()
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
    seqlen = attention_mask.shape[1]
    lengths = torch.sum(~attention_mask, 1)
    if not all([l == seqlen for l in lengths]):
        hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
    attention_mask = None
```
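
To make the mask conversion concrete, here is a small sketch on a hand-built mask. It assumes the extended additive mask layout used by `Bert`, of shape `(batch_size, 1, 1, seq_len)` with `0` for real tokens and a large negative value for padding; the nesting call itself is left out because it relies on a private PyTorch function.

```python
import torch

# Two sequences of max length 4; the second one has two padding tokens.
padding_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
# Additive mask: 0 keeps a token, a large negative value masks it.
attention_mask = (1.0 - padding_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min

# Any non-zero value becomes True, so True now means "padding".
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))

seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)  # number of real tokens per sequence
needs_nesting = not all(length == seqlen for length in lengths)
print(lengths.tolist(), needs_nesting)
```

Here the second sequence is shorter than `seqlen`, so this batch would be nested; a batch without any padding is left as a regular tensor.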

Once the `hidden_states` are nested, call `torch._transformer_encoder_layer_fwd` using the right arguments as follows:
```python
hidden_states = torch._transformer_encoder_layer_fwd(
    hidden_states,
    self.embed_dim,
    self.num_heads,
    self.in_proj_weight,
    self.in_proj_bias,
    self.out_proj_weight,
    self.out_proj_bias,
    self.use_gelu,
    self.norm_first,
    self.norm1_eps,
    self.norm1_weight,
    self.norm1_bias,
    self.norm2_weight,
    self.norm2_bias,
    self.linear1_weight,
    self.linear1_bias,
    self.linear2_weight,
    self.linear2_bias,
    attention_mask,
)
```
At the last layer, it is important to "un-nest" the `hidden_states` so that they can be processed by the next modules. This is done in these lines:
```python
if hidden_states.is_nested and self.is_last_layer:
    hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)
```
Also make sure to return a `tuple` to follow the convention of `transformers`. 
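
The effect of un-nesting can be seen on a small nested tensor. This sketch assumes a PyTorch version that ships the prototype `torch.nested` module (1.13 or later); the tensor sizes are arbitrary.

```python
import torch

# Two sequences with 2 and 4 tokens, each token of dimension 3.
nested = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(4, 3)])

# Pad every sequence back to the longest length, filling with 0.0.
padded = nested.to_padded_tensor(0.0)
print(nested.is_nested, padded.is_nested, tuple(padded.shape))
```

After padding, the result is a regular dense tensor again, which is what the modules following the encoder expect.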

The best way to reproduce this on your own model is to take inspiration from the provided modeling scripts. Of course, we will be happy to help you convert your model if you open an issue or a Pull Request on `optimum`!

### Step 4: Sanity check!

As a last step, make sure to update the `BetterTransformerManager.MODEL_MAPPING` dictionary in `optimum/bettertransformer/__init__.py` with the correct names, and you should be ready to convert your model. For example, for Bert that would be:

```python
MODEL_MAPPING = {
  ...
  "bert": ("BertLayer", BertLayerBetterTransformer),
  ...
}
```
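
To illustrate how such a mapping can drive the conversion, here is a simplified sketch of a recursive replacement keyed on the class name. It is not the implementation used by `optimum`: `ToyLayer`, `ToyLayerFast`, `TOY_MAPPING` and `replace_layers` are made-up names, and the replacement class only wraps the original layer instead of calling a fused kernel.

```python
import torch.nn as nn


class ToyLayer(nn.Module):
    """Made-up stand-in for a layer such as `BertLayer`."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)


class ToyLayerFast(nn.Module):
    """Made-up stand-in for a `BetterTransformer` layer class."""

    def __init__(self, original_layer, config=None):
        super().__init__()
        self.original_layer = original_layer


# model type -> (name of the class to replace, replacement class)
TOY_MAPPING = {"toy": ("ToyLayer", ToyLayerFast)}


def replace_layers(module, model_type):
    """Swap, in place, every child whose class name matches the mapping entry."""
    target_name, replacement_cls = TOY_MAPPING[model_type]
    # Copy the children into a list first, since the module is modified in the loop.
    for name, child in list(module.named_children()):
        if type(child).__name__ == target_name:
            setattr(module, name, replacement_cls(child))
        else:
            replace_layers(child, model_type)
    return module


model = nn.Sequential(ToyLayer(), nn.ModuleList([ToyLayer(), nn.ReLU()]))
replace_layers(model, "toy")
print(type(model[0]).__name__, type(model[1][0]).__name__)
```

Matching on the class name is why the first element of each mapping entry is a string: the conversion code does not need to import the original modeling class.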

Try it out with the conversion method presented in the [tutorials section](../tutorials/convert)!
